import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.lines as mlines
import seaborn as sns
import numpy as np
from string import punctuation
import scipy.stats as st


# Load data
df = pd.read_csv('../data/temp/sentiment_narratives.csv')
# Trim surrounding whitespace: the narrative text is later used verbatim as a
# dictionary key when looking up label offsets.
df['narrative'] = [text.strip() for text in df['narrative']]
# True for rows whose compound score is strictly greater than zero; used
# below to choose marker shape and colour.
df['above_zero'] = df['sentiment_compound'].gt(0)

# Plot
_fig, _ax = plt.subplots(figsize=(20, 8))

# X axis: limits run from -1 to 1.1, with a tick every 0.2 starting at -1.
_ax.xaxis.label.set_fontsize(25)
_ax.xaxis.labelpad = 20
_ax.set_xlim(-1, 1.1)
_ax.set_xticks(np.arange(-1, 1.1, 0.2))

# Y axis: hidden entirely; the [-5, 5] range only provides vertical room
# for whatever is drawn above and below the y = 0 line.
_ax.yaxis.label.set_fontsize(25)
#_ax.set_facecolor((.18, .31, .31, 0.0))
_ax.set_axisbelow(True)  # draw grid lines underneath the plotted artists
_ax.yaxis.set_visible(False)
_ax.set_ylim(-5, 5)

# Remove the frame around the plotting area.
for _side in ('left', 'top', 'bottom', 'right'):
    _ax.spines[_side].set_visible(False)

# Marker per value of `above_zero`: "h" (hexagon) when the compound score is
# above zero, "v" (downward triangle) otherwise.
markers = {True: "h", False: "v"}

# Draw every narrative on the horizontal line y = 0, so only the x position
# (the compound sentiment score) carries information.  The y vector is sized
# from the data instead of a hard-coded row count, so the plot keeps working
# when the CSV does not contain exactly 20 narratives.
_g = sns.scatterplot(
    data=df,
    x='sentiment_compound',
    y=[0] * len(df),
    s=300,
    style='above_zero',
    markers=markers,
    hue='above_zero',
    # NOTE(review): with a list palette the colours are assigned to the hue
    # levels in seaborn's level order (presumably False -> orange,
    # True -> green) -- confirm if the data can contain only one level.
    palette=['orange', 'green'],
    alpha=0.8,
    zorder=2,  # keep the markers above the grid lines
    ax=_ax,
)

# Point labels
# Where to put each narrative's text label, keyed by the (stripped) narrative
# string.  Each value is a pair (dx, y): dx is added to the point's x
# coordinate, and y is used directly as the label's vertical position (the
# points themselves are annotated at y = 0).
label_positions_deviations = {
    'government estimate contract': (-0.3, -3.5),
    'god continue america': (-0.5, 0.8),
    'god bless nation': (-0.2, -1.5),
    'person love family': (-0.1, 3.2),
    'constitution prohibit america': (-0.4, -2.7),
    'god grant strong': (-0.2, 1.7),
    'family give afghanistan': (-0.1, -1.5),
    'money amount taxable': (-0.1, 2.3),
    'people promote welfare': (-0.1, -2.3),
    'people establish justice': (0.0, 1.2),

    # Positive sentiment_compound
    # NOTE(review): the heading above is kept from the original, but the
    # narratives below read like the negative-score group -- confirm against
    # the data and relabel if so.
    'crime add category': (-0.4, -2.3),
    'homeland attack unit': (-0.5, 0.9),
    'law enforcement add category': (-0.3, -3.5),
    'terrorist kill american': (-0.3, 3.5),
    'saddam hussein use weapon mass destruction': (-0.1, -2.6),
    'terrorist attack country': (-0.2, 2.5),
    'iraq have weapon mass destruction': (-0.0, -1.2),
    'saddam hussein have weapon mass destruction': (-0.3, 1.5),
    'people kill american': (-0.2, -1.5),
    'men women fight war': (-0.1, 3),
}

# Attach a text label to every point, joined to it by a faint grey arrow.
# Rows are walked pairwise (narrative, score) instead of by integer label, so
# the loop does not depend on the DataFrame having a 0..n-1 index.
for narrative_text, score in zip(df['narrative'], df['sentiment_compound']):
    # (dx, y) for this label; raises KeyError if a narrative has no entry in
    # label_positions_deviations.
    dx, label_y = label_positions_deviations[narrative_text]
    _ax.annotate(
        # The label text is passed positionally: the keyword was renamed from
        # `s` to `text` in matplotlib 3.3, so `s=` fails on current releases
        # while the positional form works on both old and new versions.
        narrative_text,
        xy=(score, 0),                 # the point itself sits on y = 0
        xytext=(score + dx, label_y),  # x is relative to the point, y is absolute
        arrowprops=dict(
            arrowstyle="simple",
            shrinkA=5,
            shrinkB=15,
            color='grey',
            alpha=0.2,
            # Disabled alternative:
            # connectionstyle="angle3,angleA=0,angleB=-90"
        ),
        color='black',
        fontsize=16,
        # Disabled alternative:
        # bbox=dict(pad=5, facecolor='blue', alpha=0.2)
    )



# Tick label size on the current axes.
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)

# Dashed, semi-transparent vertical grid lines, kept underneath the artists.
_ax.set_axisbelow(True)
_ax.xaxis.grid(color='gray', alpha=0.3, linestyle='dashed')

# Drop the x-axis label and the legend created by the scatterplot call.
_g.set(xlabel=None)
_legend = _ax.get_legend()
_legend.remove()

# Output file
plt.savefig('../figures/Figure_3_b.pdf', bbox_inches="tight")
